
Conversation

Contributor

@wwwjn wwwjn commented Aug 16, 2025

Context

  1. Remove the kept copy of the GroupedExperts weight to free memory in StateDictAdapter.
  2. Add an illustration to the README about the issue that `split()` might cause OOM. More details in the following figure:
[Figure: screenshot (Aug 16, 2025) illustrating how `split()` can cause OOM]

Test

FSDP=8 (FSDP shard dim-0), num_experts = 256

[rank0]:In _split_weight function, weights: <class 'torch.distributed.tensor.DTensor'> torch.Size([256, 2048, 7168]) (Shard(dim=0),)
[rank0]:In _split_weight function, split_weight: <class 'torch.distributed.tensor.DTensor'> torch.Size([1, 2048, 7168]) (Replicate(),)

FSDP=8 (FSDP shard dim-1), num_experts = 256

[rank0]:In _split weights, <class 'torch.distributed.tensor.DTensor'> torch.Size([256, 2048, 7168]) (Shard(dim=1),)
[rank0]:In _split split_weight, <class 'torch.distributed.tensor.DTensor'> torch.Size([1, 2048, 7168]) (Shard(dim=1),)

@meta-cla meta-cla bot added the CLA Signed label Aug 16, 2025
@wwwjn wwwjn marked this pull request as ready for review August 16, 2025 22:44
Contributor

@tianyu-l tianyu-l left a comment


Had some comments. Need @fegin's input as well.

@@ -61,6 +61,7 @@ python scripts/checkpoint_conversion/convert_from_hf.py <hf_checkpoints_dir> <dc
Some limitations:
1. It can't be used to convert HF checkpoints on the fly using GPU DTensor, because the sharding and the quantized blocks may not be aligned well, causing silent numerical incorrectness.
2. It can't be used for weight sync to generate a state dict of bf16 because fake quantization to fp8 is applied.
3. When converting GroupedExperts weights to HF's separate per-expert weights on the fly, `torch.split()` can cause huge GPU memory usage. This is because torchtitan's GroupedExperts weight has shape `(num_experts, dim1, dim2)`, and FSDP shards it on dim-0 by default. When `to_hf()` calls `torch.split()` on dim-0, this incurs an all-gather and materializes replicated expert weights in memory.
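
To make the memory effect described in this note concrete, here is a minimal sketch (not torchtitan code; it assumes an 8-GPU `torchrun` launch and toy expert shapes in place of `(256, 2048, 7168)`): after a dim-0 `torch.split()` of a `Shard(0)` DTensor, every slice comes back `Replicate()`d, so the per-rank memory across the slices adds up to the full unsharded weight.

```python
# Sketch only: quantify the blow-up described above.
# Assumptions: 8 GPUs launched via `torchrun --nproc_per_node=8`, toy shapes.
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
mesh = init_device_mesh("cuda", (8,))
torch.manual_seed(0)

def local_bytes(dt):
    # Size of the shard (or replica) this rank actually materializes.
    t = dt.to_local()
    return t.numel() * t.element_size()

# GroupedExperts-like weight, FSDP-style Shard(0): each rank holds 256/8 experts.
w = distribute_tensor(torch.randn(256, 32, 64), mesh, placements=[Shard(0)])
pieces = torch.split(w, 1, dim=0)  # per-expert slices, as done in to_hf()

if dist.get_rank() == 0:
    print(w.placements, pieces[0].placements)                     # (Shard(dim=0),) (Replicate(),)
    print("before split:", local_bytes(w))                        # ~1/8 of the weight
    print("after split :", sum(local_bytes(p) for p in pieces))   # full weight on every rank
```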
Contributor


I thought more about this. Even if FSDP shards on dim-1, EP will shard on dim-0 anyway. So the problem still exists. Let's discuss next week.

Contributor


Can we perform a redistribute() before split() to ensure the expert parameter is sharded on dim-1? This redistributed, dim-1 sharded parameter will be used exclusively by the split().
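
A minimal sketch of this suggestion (assumptions: 1-D 8-rank mesh, toy shapes, not the actual adapter change): the explicit `redistribute()` costs one collective, but the following dim-0 split then keeps every expert slice sharded on dim-1.

```python
# Sketch of redistribute-before-split (assumed 8-GPU torchrun launch, toy shapes).
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
mesh = init_device_mesh("cuda", (8,))
torch.manual_seed(0)

w = distribute_tensor(torch.randn(256, 32, 64), mesh, placements=[Shard(0)])
w_dim1 = w.redistribute(mesh, placements=[Shard(1)])  # one explicit collective
pieces = torch.split(w_dim1, 1, dim=0)                # dim-0 split of a Shard(1) tensor

if dist.get_rank() == 0:
    print(pieces[0].placements)  # (Shard(dim=1),) -> slices stay sharded, no replication
```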

Contributor


With EP it's sharded on dim-0 anyway. Performing this redistribute means at least 1 comm in to_hf and at least 1 comm in from_hf.
If both EP and FSDP dim-0 sharding is used, we'll have strided sharding whose redistribute algo today may not be efficient or even correct.

Contributor

@fegin fegin Aug 18, 2025


The redistribution algorithm should be correct, but whether it is going to be efficient is debatable. I think it will be more efficient than an all-gather, since less communication is incurred, even if it is not optimal.

There should be no extra comm in from_hf, as DCP.load will handle the resharding, but this resharding can be slow for sure.
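
As a rough illustration of the DCP point (a sketch under assumptions: 8-GPU launch, toy shapes, a hypothetical checkpoint path; not torchtitan's checkpoint code): a state dict saved with one sharding can be loaded into DTensors with a different sharding, and `dcp.load` performs the resharding without an explicit redistribute in user code.

```python
# Sketch of DCP resharding (assumed 8-GPU torchrun launch, toy shapes,
# hypothetical checkpoint path "/tmp/dtensor_ckpt" on a shared filesystem).
import os
import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
mesh = init_device_mesh("cuda", (8,))
torch.manual_seed(0)

# Save the weight sharded on dim-0.
w_save = distribute_tensor(torch.randn(256, 32, 64), mesh, placements=[Shard(0)])
dcp.save({"experts.w": w_save}, checkpoint_id="/tmp/dtensor_ckpt")

# Load it back into a dim-1-sharded DTensor; dcp.load reshards into w_load in place.
w_load = distribute_tensor(torch.zeros(256, 32, 64), mesh, placements=[Shard(1)])
dcp.load({"experts.w": w_load}, checkpoint_id="/tmp/dtensor_ckpt")
```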

@@ -158,6 +158,9 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
new_key = new_abstract_key.format(layer_num, expert_num)
hf_state_dict[new_key] = split_values[expert_num].squeeze()

# Remove the GroupedExperts' weight from the state_dict to free memory
del value
Contributor


I think for loading a checkpoint synchronously, this sounds fine.
But for saving, after calling to_hf we may still need the original weights for the next training steps.

Contributor Author


I see, that's a valid concern. If a user periodically saves a checkpoint in HF format, this would be an issue. I checked checkpoint.py, and it only supports last_save_in_hf in _save_last_step; we don't support saving in HF format in between.

Contributor


The adapter is independent of checkpoint.py in torchtitan. In RL weight sync, it will be called without checkpointing.

@wwwjn wwwjn changed the title [DSV3] Remove keep a copy of GroupedExperts weight, free memory in StateDictAdapter [WIP][DSV3] Remove keep a copy of GroupedExperts weight, free memory in StateDictAdapter Aug 19, 2025